"""
Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR)
Copyright(c) 2023 lyuwenyu. All Rights Reserved.
"""

import functools
import importlib
import inspect
from collections import defaultdict
from typing import Any, Dict, List, Optional

# Default registry used by `register` and `create`. It is a defaultdict, so
# reading an unknown key yields a fresh empty dict instead of raising KeyError.
GLOBAL_CONFIG = defaultdict(dict)


def register(dct: Any = GLOBAL_CONFIG, name=None, force=False):
    """Return a decorator that registers a function or a class into ``dct``.

    Functions are wrapped (with ``functools.wraps``) and the wrapper is both
    stored and returned. Classes are stored as the schema produced by
    ``extract_schema`` and the class itself is returned unchanged.

    Args:
        dct: registration target.
            If it is a dict, the entry is stored as a key-value pair.
            If it is a class, a function is stored as an attribute of that
            class. Classes being registered are stored with item assignment,
            so they need a mapping-like ``dct``.
        name: key (or attribute name) to register under. Defaults to the
            ``__name__`` of the decorated object.
        force: if True, skip the "already registered" check and overwrite.

    Return:
        Callable, the decorator.

    Raises (from the decorator):
        AssertionError: the name is already taken and ``force`` is False.
        AttributeError: a function is registered into something that is
            neither a dict nor a class.
        ValueError: the decorated object is neither a function nor a class.
    """

    def decorator(foo):
        register_name = foo.__name__ if name is None else name

        # Check the name the entry will actually be stored under. Raised
        # explicitly (not via `assert`) so the check survives `python -O`
        # while keeping the exception type callers already see.
        if not force:
            if inspect.isclass(dct):
                if hasattr(dct, register_name):
                    raise AssertionError(f"module {dct.__name__} has {register_name}")
            elif register_name in dct:
                raise AssertionError(f"{register_name} has been already registered")

        if inspect.isfunction(foo):

            @functools.wraps(foo)
            def wrap_func(*args, **kwargs):
                return foo(*args, **kwargs)

            if isinstance(dct, dict):
                dct[register_name] = wrap_func
            elif inspect.isclass(dct):
                setattr(dct, register_name, wrap_func)
            else:
                raise AttributeError(
                    f"cannot register function into {type(dct)}; expected a dict or a class"
                )
            return wrap_func

        if inspect.isclass(foo):
            dct[register_name] = extract_schema(foo)
            return foo

        raise ValueError(f"Do not support {type(foo)} register")

    return decorator


def extract_schema(module: type):
    """Build the registry schema describing how to construct ``module``.

    The schema is a plain dict. Bookkeeping entries use keys starting with
    ``_``; every ``__init__`` argument (except ``self``) additionally gets a
    top-level entry holding its current value:

        ``_name``      ``module.__name__``
        ``_pymodule``  the imported Python module named by ``module.__module__``
        ``_inject``    the class attribute ``__inject__`` (``[]`` if absent)
        ``_share``     the class attribute ``__share__`` (``[]`` if absent)
        ``_kwargs``    argument name -> default value (``None`` if no default)

    Only the arguments reported in ``getfullargspec(...).args`` are recorded.

    Args:
        module (type), class whose ``__init__`` signature is inspected.
    Return:
        Dict, the schema described above.
    Raises:
        AssertionError, if a name listed in ``__share__`` is an ``__init__``
        argument without a default value.
    """
    argspec = inspect.getfullargspec(module.__init__)
    arg_names = [arg for arg in argspec.args if arg != "self"]
    defaults = argspec.defaults or ()
    # Defaults always belong to the trailing arguments.
    num_requires = len(arg_names) - len(defaults)

    shared = getattr(module, "__share__", [])
    schema = {
        "_name": module.__name__,
        "_pymodule": importlib.import_module(module.__module__),
        "_inject": getattr(module, "__inject__", []),
        "_share": shared,
        "_kwargs": {},
    }

    for i, name in enumerate(arg_names):
        if i >= num_requires:
            value = defaults[i - num_requires]
        elif name in shared:
            # Raised explicitly rather than asserted: with the check stripped
            # (python -O) a negative index would silently pick an unrelated
            # default for this argument.
            raise AssertionError("share config must have default value.")
        else:
            value = None

        schema[name] = value
        schema["_kwargs"][name] = value

    return schema


def _overlay_schema(schema, *overrides):
    """Reset ``schema`` to its declared defaults, apply ``overrides`` in order, and pop ``type``."""
    # Drop every argument value left over from a previous use of this schema.
    for stale in [key for key in schema if not key.startswith("_")]:
        del schema[stale]
    schema.update(schema["_kwargs"])  # restore default args
    for extra in overrides:
        schema.update(extra)
    # `type` came from the overrides and is not a constructor argument.
    return schema.pop("type")


def create(type_or_name, global_cfg=GLOBAL_CONFIG, **kwargs):
    """Instantiate a registered module from ``global_cfg``.

    Resolution order for the entry found under the requested name:

    1. If the entry has a ``__dict__`` it is returned as-is.
    2. If the entry is a dict containing ``type``, it is treated as an alias:
       the schema registered under that type is reset to its defaults,
       overlaid with the entry and with ``kwargs``, and created recursively.
       The registered schema is modified in place.
    3. Otherwise the entry is used as a schema: values listed in ``_share``
       are taken from ``global_cfg`` when present there, values listed in
       ``_inject`` are resolved (by name, or by a ``type``-style dict) and
       created recursively, and the class is called with every schema entry
       whose key does not start with ``_``.

    Note that ``kwargs`` are only applied in case 2.

    Args:
        type_or_name: a class (looked up by its ``__name__``) or a name.
        global_cfg: registry to resolve names in.
        **kwargs: extra overrides for a ``type``-style entry.

    Return:
        The registered object (case 1) or a new instance.

    Raises:
        AssertionError: ``type_or_name`` is neither a class nor a str.
        ValueError: the name (or an injected dependency) is not registered,
            or an inject value has an unsupported form.
    """
    # isinstance (not an exact type() match) so classes built with a custom
    # metaclass are accepted; raised explicitly so the check survives -O.
    if not isinstance(type_or_name, (type, str)):
        raise AssertionError("create should be modules or name.")

    name = type_or_name if isinstance(type_or_name, str) else type_or_name.__name__

    if name not in global_cfg:
        raise ValueError("The module {} is not registered".format(name))

    cfg = global_cfg[name]
    if hasattr(cfg, "__dict__"):
        return cfg

    if isinstance(cfg, dict) and "type" in cfg:
        # TODO receive extra kwargs
        target = _overlay_schema(global_cfg[cfg["type"]], cfg, kwargs)
        return create(target, global_cfg)

    # Resolve the class by its real name; the registry key may be an alias.
    class_name = cfg["_name"] if "_name" in cfg else name
    module = getattr(cfg["_pymodule"], class_name)
    module_kwargs = dict(cfg)

    # shared var: a top-level entry in global_cfg wins over the schema value
    for k in cfg["_share"]:
        module_kwargs[k] = global_cfg[k] if k in global_cfg else cfg[k]

    # inject
    for k in cfg["_inject"]:
        _k = cfg[k]

        if _k is None:
            continue

        if isinstance(_k, str):
            if _k not in global_cfg:
                raise ValueError(f"Missing inject config of {_k}.")

            _cfg = global_cfg[_k]

            if isinstance(_cfg, dict):
                module_kwargs[k] = create(_cfg["_name"], global_cfg)
            else:
                module_kwargs[k] = _cfg

        elif isinstance(_k, dict):
            if "type" not in _k:
                raise ValueError("Missing inject for `type` style.")

            _type = str(_k["type"])
            if _type not in global_cfg:
                raise ValueError(f"Missing {_type} in inspect stage.")

            target = _overlay_schema(global_cfg[_type], _k)
            module_kwargs[k] = create(target, global_cfg)

        else:
            raise ValueError(f"Inject does not support {_k}")

    # TODO hard code: bookkeeping keys are not constructor arguments
    module_kwargs = {k: v for k, v in module_kwargs.items() if not k.startswith("_")}

    # TODO for **kwargs: reject arguments the constructor does not declare.

    return module(**module_kwargs)
